/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Einsum, util } from '@tensorflow/tfjs-core';
import { multiply } from './Multiply';
import { reshape } from './Reshape';
import { sum } from './Sum';
import { transpose } from './Transpose';
/**
 * CPU kernel function for the `Einsum` op.
 *
 * The equation is decoded with `backend_util.decodeEinsumEquation`, the input
 * sizes are validated with `backend_util.checkEinsumDimSizes`, and the
 * computation plan (`path` and `steps`) comes from
 * `backend_util.getEinsumComputePath`. For every step, each listed input term
 * is:
 *   1. transposed, unless its permutation is the identity;
 *   2. reshaped so that a size-1 dimension is inserted at every index in
 *      `expandDims`, when that changes the shape;
 *   3. multiplied into the running result.
 * After every step except the last, the running result is summed over one
 * axis when the step's `path` entry is non-negative.
 *
 * Every tensor created along the way is released through
 * `backend.disposeIntermediateTensorInfo`, both when the kernel succeeds and
 * when any helper throws. On success the returned tensor is spared.
 *
 * NOTE(review): when no transpose, reshape, multiply or sum is needed, the
 * returned value is the very same object that was passed in `inputs`; confirm
 * that callers tolerate this aliasing.
 *
 * @param {Object} args Kernel arguments.
 * @param {Array} args.inputs Input tensor infos, accessed by numeric index;
 *     their `length` and each element's `shape` are read here.
 * @param {Object} args.backend Backend handed to every sub-kernel and used to
 *     dispose intermediates.
 * @param {Object} args.attrs Kernel attributes.
 * @param {string} args.attrs.equation The einsum equation to evaluate.
 * @returns {Object} The tensor info holding the result.
 * @throws Rethrows any error raised by the helpers or sub-kernels, after the
 *     intermediates created so far have been disposed.
 */
export function einsum(args) {
    const { inputs, backend, attrs } = args;
    const { equation } = attrs;
    const tensors = inputs;
    const { allDims, summedDims, idDims } = backend_util.decodeEinsumEquation(equation, tensors.length);
    backend_util.checkEinsumDimSizes(allDims.length, idDims, tensors);
    const { path, steps } = backend_util.getEinsumComputePath(summedDims, idDims);
    const nSteps = steps.length;
    let out = null;
    let numDimsRemaining = allDims.length;
    // Everything this kernel allocates; inputs are never added here.
    const tensorsToDispose = [];
    // Flipped only once the whole computation has finished, so that the
    // cleanup below knows whether `out` is a result to keep or a leftover.
    let completed = false;
    try {
        for (let i = 0; i < nSteps; ++i) {
            for (const idTerm of steps[i]) {
                const { permutationIndices: perm, expandDims: dimsToExpand } = backend_util.getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
                let x;
                if (backend_util.isIdentityPermutation(perm)) {
                    // No reordering needed: use the input as is.
                    x = tensors[idTerm];
                }
                else {
                    x = transpose({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
                    tensorsToDispose.push(x);
                }
                // Insert a size-1 dimension at every index listed in
                // `dimsToExpand`.
                const targetShape = x.shape.slice();
                for (let k = 0; k < dimsToExpand.length; ++k) {
                    targetShape.splice(dimsToExpand[k], 0, 1);
                }
                if (!util.arraysEqual(x.shape, targetShape)) {
                    x = reshape({ inputs: { x }, backend, attrs: { shape: targetShape } });
                    tensorsToDispose.push(x);
                }
                if (out === null) {
                    out = x;
                }
                else {
                    out = multiply({ inputs: { a: x, b: out }, backend });
                    tensorsToDispose.push(out);
                }
            }
            if (i < nSteps - 1) {
                // A negative path entry means no reduction for this step.
                if (path[i] >= 0) {
                    out = sum({
                        inputs: { x: out },
                        backend,
                        attrs: {
                            // Offset the axis by the number of dimensions
                            // already dropped in earlier steps.
                            axis: path[i] - (allDims.length - numDimsRemaining),
                            keepDims: false
                        }
                    });
                    tensorsToDispose.push(out);
                }
                numDimsRemaining--;
            }
        }
        completed = true;
        return out;
    }
    finally {
        // Clean up intermediate tensors. On failure nothing is returned, so
        // the partial result is released as well.
        for (const tensorInfo of tensorsToDispose) {
            if (completed && tensorInfo === out) {
                continue;
            }
            backend.disposeIntermediateTensorInfo(tensorInfo);
        }
    }
}
/**
 * Kernel registration record binding the `Einsum` kernel name to the `einsum`
 * implementation above for the backend named 'cpu'.
 */
export const einsumConfig = {
    kernelName: Einsum,
    backendName: 'cpu',
    kernelFunc: einsum,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRWluc3VtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9FaW5zdW0udHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBRSxNQUFNLEVBQTJFLElBQUksRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBSTFJLE9BQU8sRUFBQyxRQUFRLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFDcEMsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxTQUFTLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFdEMsTUFBTSxVQUFVLE1BQU0sQ0FDbEIsSUFBeUU7SUFFM0UsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxRQUFRLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFDekIsTUFBTSxPQUFPLEdBQUcsTUFBa0IsQ0FBQztJQUVuQyxNQUFNLEVBQUMsT0FBTyxFQUFFLFVBQVUsRUFBRSxNQUFNLEVBQUMsR0FDL0IsWUFBWSxDQUFDLG9CQUFvQixDQUFDLFFBQVEsRUFBRSxPQUFPLENBQUMsTUFBTSxDQUFDLENBQUM7SUFDaEUsWUFBWSxDQUFDLG1CQUFtQixDQUFDLE9BQU8sQ0FBQyxNQUFNLEVBQUUsTUFBTSxFQUFFLE9BQU8sQ0FBQyxDQUFDO0lBQ2xFLE1BQU0sRUFBQyxJQUFJLEVBQUUsS0FBSyxFQUFDLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLFVBQVUsRUFBRSxNQUFNLENBQUMsQ0FBQztJQUU1RSxNQUFNLE1BQU0sR0FBRyxLQUFLLENBQUMsTUFBTSxDQUFDO0lBQzVCLElBQUksR0FBRyxHQUFvQixJQUFJLENBQUM7SUFDaEMsSUFBSSxnQkFBZ0IsR0FBRyxPQUFPLENBQUMsTUFBTSxDQUFDO0lBQ3RDLE1BQU0sZ0JBQWdCLEdBQWlCLEVBQUUsQ0FBQztJQUMxQyxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQy9CLEtBQUssTUFBTSxNQUFNLElBQUksS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFO1lBQzdCLE1BQU0sRUFBQyxrQkFBa0IsRUFBRSxJQUFJLEVBQUUsVUFBVSxFQUFFLFlBQVksRUFBQyxHQUN0RCxZQUFZLENBQUMsb0JBQW9CLENBQUMsZ0JBQWdCLEVBQUUsTUFBTSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUM7WUFDeEUsSUFBSSxDQUFhLENBQUM7WUFDbEIsSUFBSSxZQUFZLENBQUMscUJBQXFCLENBQUMsSUFBSSxDQUFDLEVBQUU7Z0JBQzVDLENBQUMsR0FBRyxPQUFPLENBQUMsTUFBTSxDQUFDLENBQUM7YUFDckI7aUJBQU07Z0JBQ0wsQ0FBQyxHQUFHLFNBQVMsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsTUFBTSxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQU
MsSUFBSSxFQUFDLEVBQUMsQ0FBQyxDQUFDO2dCQUN0RSxnQkFBZ0IsQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDLENBQUM7YUFDMUI7WUFDRCxNQUFNLFdBQVcsR0FBYSxDQUFDLENBQUMsS0FBSyxDQUFDLEtBQUssRUFBRSxDQUFDO1lBQzlDLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxZQUFZLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO2dCQUM1QyxXQUFXLENBQUMsTUFBTSxDQUFDLFlBQVksQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUM7YUFDM0M7WUFFRCxJQUFJLENBQUMsSUFBSSxDQUFDLFdBQVcsQ0FBQyxDQUFDLENBQUMsS0FBSyxFQUFFLFdBQVcsQ0FBQyxFQUFFO2dCQUMzQyxDQUFDLEdBQUcsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxXQUFXLEVBQUMsRUFBQyxDQUFDLENBQUM7Z0JBQ2pFLGdCQUFnQixDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQzthQUMxQjtZQUNELElBQUksR0FBRyxLQUFLLElBQUksRUFBRTtnQkFDaEIsR0FBRyxHQUFHLENBQUMsQ0FBQzthQUNUO2lCQUFNO2dCQUNMLDBEQUEwRDtnQkFDMUQsR0FBRyxHQUFHLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLEdBQUcsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFlLENBQUM7Z0JBQ2hFLGdCQUFnQixDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsQ0FBQzthQUM1QjtTQUNGO1FBQ0QsSUFBSSxDQUFDLEdBQUcsTUFBTSxHQUFHLENBQUMsRUFBRTtZQUNsQixJQUFJLElBQUksQ0FBQyxDQUFDLENBQUMsSUFBSSxDQUFDLEVBQUU7Z0JBQ2hCLEdBQUcsR0FBRyxHQUFHLENBQUM7b0JBQ1IsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLEdBQUcsRUFBQztvQkFDaEIsT0FBTztvQkFDUCxLQUFLLEVBQUU7d0JBQ0wsSUFBSSxFQUFFLElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLE9BQU8sQ0FBQyxNQUFNLEdBQUcsZ0JBQWdCLENBQUM7d0JBQ25ELFFBQVEsRUFBRSxLQUFLO3FCQUNoQjtpQkFDRixDQUFDLENBQUM7Z0JBQ0gsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxDQUFDO2FBQzVCO1lBQ0QsZ0JBQWdCLEVBQUUsQ0FBQztTQUNwQjtLQUNGO0lBRUQsaUNBQWlDO0lBQ2pDLEtBQUssTUFBTSxVQUFVLElBQUksZ0JBQWdCLEVBQUU7UUFDekMsSUFBSSxVQUFVLEtBQUssR0FBRyxFQUFFO1lBQ3RCLFNBQVM7U0FDVjtRQUNELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxVQUFVLENBQUMsQ0FBQztLQUNuRDtJQUVELE9BQU8sR0FBRyxDQUFDO0FBQ2IsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLFlBQVksR0FBaUI7SUFDeEMsVUFBVSxFQUFFLE1BQU07SUFDbEIsV0FBVyxFQUFFLEtBQUs7SUFDbEIsVUFBVSxFQUFFLE1BQStCO0NBQzVDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cm
lnaHQgMjAyMSBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7YmFja2VuZF91dGlsLCBFaW5zdW0sIEVpbnN1bUF0dHJzLCBFaW5zdW1JbnB1dHMsIEtlcm5lbENvbmZpZywgS2VybmVsRnVuYywgVGVuc29yLCBUZW5zb3JJbmZvLCB1dGlsfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge01hdGhCYWNrZW5kQ1BVfSBmcm9tICcuLi9iYWNrZW5kX2NwdSc7XG5cbmltcG9ydCB7bXVsdGlwbHl9IGZyb20gJy4vTXVsdGlwbHknO1xuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL1Jlc2hhcGUnO1xuaW1wb3J0IHtzdW19IGZyb20gJy4vU3VtJztcbmltcG9ydCB7dHJhbnNwb3NlfSBmcm9tICcuL1RyYW5zcG9zZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBlaW5zdW0oXG4gICAgYXJnczoge2lucHV0czogRWluc3VtSW5wdXRzLCBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSwgYXR0cnM6IEVpbnN1bUF0dHJzfSk6XG4gICAgVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHtlcXVhdGlvbn0gPSBhdHRycztcbiAgY29uc3QgdGVuc29ycyA9IGlucHV0cyBhcyBUZW5zb3JbXTtcblxuICBjb25zdCB7YWxsRGltcywgc3VtbWVkRGltcywgaWREaW1zfSA9XG4gICAgICBiYWNrZW5kX3V0aWwuZGVjb2RlRWluc3VtRXF1YXRpb24oZXF1YXRpb24sIHRlbnNvcnMubGVuZ3RoKTtcbiAgYmFja2VuZF91dGlsLmNoZWNrRWluc3VtRGltU2l6ZXMoYWxsRGltcy5sZW5ndGgsIGlkRGltcywgdGVuc29ycyk7XG4gIGNvbnN0IHtwYXRoLCBzdGVwc30gPSBiYWNrZW5kX3V0aWwuZ2V0RWluc3VtQ29tcH
V0ZVBhdGgoc3VtbWVkRGltcywgaWREaW1zKTtcblxuICBjb25zdCBuU3RlcHMgPSBzdGVwcy5sZW5ndGg7XG4gIGxldCBvdXQ6IFRlbnNvckluZm98bnVsbCA9IG51bGw7XG4gIGxldCBudW1EaW1zUmVtYWluaW5nID0gYWxsRGltcy5sZW5ndGg7XG4gIGNvbnN0IHRlbnNvcnNUb0Rpc3Bvc2U6IFRlbnNvckluZm9bXSA9IFtdO1xuICBmb3IgKGxldCBpID0gMDsgaSA8IG5TdGVwczsgKytpKSB7XG4gICAgZm9yIChjb25zdCBpZFRlcm0gb2Ygc3RlcHNbaV0pIHtcbiAgICAgIGNvbnN0IHtwZXJtdXRhdGlvbkluZGljZXM6IHBlcm0sIGV4cGFuZERpbXM6IGRpbXNUb0V4cGFuZH0gPVxuICAgICAgICAgIGJhY2tlbmRfdXRpbC5nZXRFaW5zdW1QZXJtdXRhdGlvbihudW1EaW1zUmVtYWluaW5nLCBpZERpbXNbaWRUZXJtXSk7XG4gICAgICBsZXQgeDogVGVuc29ySW5mbztcbiAgICAgIGlmIChiYWNrZW5kX3V0aWwuaXNJZGVudGl0eVBlcm11dGF0aW9uKHBlcm0pKSB7XG4gICAgICAgIHggPSB0ZW5zb3JzW2lkVGVybV07XG4gICAgICB9IGVsc2Uge1xuICAgICAgICB4ID0gdHJhbnNwb3NlKHtpbnB1dHM6IHt4OiB0ZW5zb3JzW2lkVGVybV19LCBiYWNrZW5kLCBhdHRyczoge3Blcm19fSk7XG4gICAgICAgIHRlbnNvcnNUb0Rpc3Bvc2UucHVzaCh4KTtcbiAgICAgIH1cbiAgICAgIGNvbnN0IHRhcmdldFNoYXBlOiBudW1iZXJbXSA9IHguc2hhcGUuc2xpY2UoKTtcbiAgICAgIGZvciAobGV0IGsgPSAwOyBrIDwgZGltc1RvRXhwYW5kLmxlbmd0aDsgKytrKSB7XG4gICAgICAgIHRhcmdldFNoYXBlLnNwbGljZShkaW1zVG9FeHBhbmRba10sIDAsIDEpO1xuICAgICAgfVxuXG4gICAgICBpZiAoIXV0aWwuYXJyYXlzRXF1YWwoeC5zaGFwZSwgdGFyZ2V0U2hhcGUpKSB7XG4gICAgICAgIHggPSByZXNoYXBlKHtpbnB1dHM6IHt4fSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogdGFyZ2V0U2hhcGV9fSk7XG4gICAgICAgIHRlbnNvcnNUb0Rpc3Bvc2UucHVzaCh4KTtcbiAgICAgIH1cbiAgICAgIGlmIChvdXQgPT09IG51bGwpIHtcbiAgICAgICAgb3V0ID0geDtcbiAgICAgIH0gZWxzZSB7XG4gICAgICAgIC8vIHRzbGludDpkaXNhYmxlLW5leHQtbGluZTogbm8tdW5uZWNlc3NhcnktdHlwZS1hc3NlcnRpb25cbiAgICAgICAgb3V0ID0gbXVsdGlwbHkoe2lucHV0czoge2E6IHgsIGI6IG91dH0sIGJhY2tlbmR9KSBhcyBUZW5zb3JJbmZvO1xuICAgICAgICB0ZW5zb3JzVG9EaXNwb3NlLnB1c2gob3V0KTtcbiAgICAgIH1cbiAgICB9XG4gICAgaWYgKGkgPCBuU3RlcHMgLSAxKSB7XG4gICAgICBpZiAocGF0aFtpXSA+PSAwKSB7XG4gICAgICAgIG91dCA9IHN1bSh7XG4gICAgICAgICAgaW5wdXRzOiB7eDogb3V0fSxcbiAgICAgICAgICBiYWNrZW5kLFxuICAgICAgICAgIGF0dHJzOiB7XG4gICAgICAgICAgICBheGlzOiBwYXRoW2ldIC0gKGFsbERpbXMubGVuZ3RoIC0gbnVtRGltc1JlbWFpbmluZyksXG4gICAgICAgICAgICBrZWVwRGltczogZmFsc2
VcbiAgICAgICAgICB9XG4gICAgICAgIH0pO1xuICAgICAgICB0ZW5zb3JzVG9EaXNwb3NlLnB1c2gob3V0KTtcbiAgICAgIH1cbiAgICAgIG51bURpbXNSZW1haW5pbmctLTtcbiAgICB9XG4gIH1cblxuICAvLyBDbGVhbiB1cCBpbnRlcm1lZGlhdGUgdGVuc29ycy5cbiAgZm9yIChjb25zdCB0ZW5zb3JJbmZvIG9mIHRlbnNvcnNUb0Rpc3Bvc2UpIHtcbiAgICBpZiAodGVuc29ySW5mbyA9PT0gb3V0KSB7XG4gICAgICBjb250aW51ZTtcbiAgICB9XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyh0ZW5zb3JJbmZvKTtcbiAgfVxuXG4gIHJldHVybiBvdXQ7XG59XG5cbmV4cG9ydCBjb25zdCBlaW5zdW1Db25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogRWluc3VtLFxuICBiYWNrZW5kTmFtZTogJ2NwdScsXG4gIGtlcm5lbEZ1bmM6IGVpbnN1bSBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=